{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dynamic Recurrent Neural Network.\n",
    "\n",
    "TensorFlow 2.0 implementation of a Recurrent Neural Network (LSTM) that performs dynamic computation over sequences with variable length. This example is using a toy dataset to classify linear sequences. The generated sequences have variable length.\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN Overview\n",
    "\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png\" alt=\"nn\" style=\"width: 600px;\"/>\n",
    "\n",
    "References:\n",
    "- [Long Short Term Memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf), Sepp Hochreiter & Jurgen Schmidhuber, Neural Computation 9(8): 1735-1780, 1997."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "# Import TensorFlow v2.\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import Model, layers\n",
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset parameters.\n",
    "num_classes = 2 # linear sequence or not.\n",
    "seq_max_len = 20 # Maximum sequence length.\n",
    "seq_min_len = 5 # Minimum sequence length (before padding).\n",
    "masking_val = -1 # -1 will represents the mask and be used to pad sequences to a common max length.\n",
    "max_value = 10000 # Maximum int value.\n",
    "\n",
    "# Training Parameters\n",
    "learning_rate = 0.001\n",
    "training_steps = 2000\n",
    "batch_size = 64\n",
    "display_step = 100\n",
    "\n",
    "# Network Parameters\n",
    "num_units = 32 # number of neurons for the LSTM layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ====================\n",
    "#  TOY DATA GENERATOR\n",
    "# ====================\n",
    "\n",
    "def toy_sequence_data():\n",
    "    \"\"\" Generate sequence of data with dynamic length.\n",
    "    This function generates toy samples for training:\n",
    "    - Class 0: linear sequences (i.e. [1, 2, 3, 4, ...])\n",
    "    - Class 1: random sequences (i.e. [9, 3, 10, 7,...])\n",
    "\n",
    "    NOTICE:\n",
    "    We have to pad each sequence to reach 'seq_max_len' for TensorFlow\n",
    "    consistency (we cannot feed a numpy array with inconsistent\n",
    "    dimensions). The dynamic calculation will then be perform and ignore\n",
    "    the masked value (here -1).\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        # Set variable sequence length.\n",
    "        seq_len = random.randint(seq_min_len, seq_max_len)\n",
    "        rand_start = random.randint(0, max_value - seq_len)\n",
    "        # Add a random or linear int sequence (50% prob).\n",
    "        if random.random() < .5:\n",
    "            # Generate a linear sequence.\n",
    "            seq = np.arange(start=rand_start, stop=rand_start+seq_len)\n",
    "            # Rescale values to [0., 1.].\n",
    "            seq = seq / max_value\n",
    "            # Pad sequence until the maximum length for dimension consistency.\n",
    "            # Masking value: -1.\n",
    "            seq = np.pad(seq, mode='constant', pad_width=(0, seq_max_len-seq_len), constant_values=masking_val)\n",
    "            label = 0\n",
    "        else:\n",
    "            # Generate a random sequence.\n",
    "            seq = np.random.randint(max_value, size=seq_len)\n",
    "            # Rescale values to [0., 1.].\n",
    "            seq = seq / max_value\n",
    "            # Pad sequence until the maximum length for dimension consistency.\n",
    "            # Masking value: -1.\n",
    "            seq = np.pad(seq, mode='constant', pad_width=(0, seq_max_len-seq_len), constant_values=masking_val)\n",
    "            label = 1\n",
    "        yield np.array(seq, dtype=np.float32), np.array(label, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use tf.data API to shuffle and batch data.\n",
    "train_data = tf.data.Dataset.from_generator(toy_sequence_data, output_types=(tf.float32, tf.float32))\n",
    "train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create LSTM Model.\n",
    "class LSTM(Model):\n",
    "    # Set layers.\n",
    "    def __init__(self):\n",
    "        super(LSTM, self).__init__()\n",
    "        # Define a Masking Layer with -1 as mask.\n",
    "        self.masking = layers.Masking(mask_value=masking_val)\n",
    "        # Define a LSTM layer to be applied over the Masking layer.\n",
    "        # Dynamic computation will automatically be performed to ignore -1 values.\n",
    "        self.lstm = layers.LSTM(units=num_units)\n",
    "        # Output fully connected layer (2 classes: linear or random seq).\n",
    "        self.out = layers.Dense(num_classes)\n",
    "\n",
    "    # Set forward pass.\n",
    "    def call(self, x, is_training=False):\n",
    "        # A RNN Layer expects a 3-dim input (batch_size, seq_len, num_features).\n",
    "        x = tf.reshape(x, shape=[-1, seq_max_len, 1])\n",
    "        # Apply Masking layer.\n",
    "        x = self.masking(x)\n",
    "        # Apply LSTM layer.\n",
    "        x = self.lstm(x)\n",
    "        # Apply output layer.\n",
    "        x = self.out(x)\n",
    "        if not is_training:\n",
    "            # tf cross entropy expect logits without softmax, so only\n",
    "            # apply softmax when not training.\n",
    "            x = tf.nn.softmax(x)\n",
    "        return x\n",
    "\n",
    "# Build LSTM model.\n",
    "lstm_net = LSTM()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-Entropy Loss.\n",
    "# Note that this will apply 'softmax' to the logits.\n",
    "def cross_entropy_loss(x, y):\n",
    "    # Convert labels to int 64 for tf cross-entropy function.\n",
    "    y = tf.cast(y, tf.int64)\n",
    "    # Apply softmax to logits and compute cross-entropy.\n",
    "    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)\n",
    "    # Average loss across the batch.\n",
    "    return tf.reduce_mean(loss)\n",
    "\n",
    "# Accuracy metric.\n",
    "def accuracy(y_pred, y_true):\n",
    "    # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n",
    "    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n",
    "    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)\n",
    "\n",
    "# Adam optimizer.\n",
    "optimizer = tf.optimizers.Adam(learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization process. \n",
    "def run_optimization(x, y):\n",
    "    # Wrap computation inside a GradientTape for automatic differentiation.\n",
    "    with tf.GradientTape() as g:\n",
    "        # Forward pass.\n",
    "        pred = lstm_net(x, is_training=True)\n",
    "        # Compute loss.\n",
    "        loss = cross_entropy_loss(pred, y)\n",
    "        \n",
    "    # Variables to update, i.e. trainable variables.\n",
    "    trainable_variables = lstm_net.trainable_variables\n",
    "\n",
    "    # Compute gradients.\n",
    "    gradients = g.gradient(loss, trainable_variables)\n",
    "    \n",
    "    # Update weights following gradients.\n",
    "    optimizer.apply_gradients(zip(gradients, trainable_variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 1, loss: 0.695558, accuracy: 0.562500\n",
      "step: 100, loss: 0.664089, accuracy: 0.609375\n",
      "step: 200, loss: 0.408653, accuracy: 0.812500\n",
      "step: 300, loss: 0.417392, accuracy: 0.828125\n",
      "step: 400, loss: 0.495420, accuracy: 0.765625\n",
      "step: 500, loss: 0.524736, accuracy: 0.703125\n",
      "step: 600, loss: 0.401653, accuracy: 0.859375\n",
      "step: 700, loss: 0.315812, accuracy: 0.906250\n",
      "step: 800, loss: 0.394490, accuracy: 0.828125\n",
      "step: 900, loss: 0.327425, accuracy: 0.875000\n",
      "step: 1000, loss: 0.312831, accuracy: 0.843750\n",
      "step: 1100, loss: 0.251562, accuracy: 0.875000\n",
      "step: 1200, loss: 0.192276, accuracy: 0.906250\n",
      "step: 1300, loss: 0.173289, accuracy: 0.906250\n",
      "step: 1400, loss: 0.159411, accuracy: 0.937500\n",
      "step: 1500, loss: 0.138854, accuracy: 0.921875\n",
      "step: 1600, loss: 0.046906, accuracy: 0.984375\n",
      "step: 1700, loss: 0.121232, accuracy: 0.937500\n",
      "step: 1800, loss: 0.067761, accuracy: 1.000000\n",
      "step: 1900, loss: 0.134532, accuracy: 0.968750\n",
      "step: 2000, loss: 0.090837, accuracy: 0.953125\n"
     ]
    }
   ],
   "source": [
    "# Run training for the given number of steps.\n",
    "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n",
    "    # Run the optimization to update W and b values.\n",
    "    run_optimization(batch_x, batch_y)\n",
    "    \n",
    "    if step % display_step == 0 or step == 1:\n",
    "        pred = lstm_net(batch_x, is_training=True)\n",
    "        loss = cross_entropy_loss(pred, batch_y)\n",
    "        acc = accuracy(pred, batch_y)\n",
    "        print(\"step: %i, loss: %f, accuracy: %f\" % (step, loss, acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
